# NOTE: the order of these star imports matters -- later modules shadow
# names from earlier ones (e.g. `dot` ends up being VPython's vector dot).
from scipy import *
from numpy import *
from visual import *
import time
import numpy as np


class planeWave:
    """One plane-wave basis function evaluated as a 4-component spinor.

    ``spinor`` returns ``[s1up, s1down, s2up, s2down]``. Both "up"
    components carry the same plane-wave amplitude; both "down"
    components are always 0.
    """

    def __init__(self, sizeOfLattice, center, elec, direction, mode, coeff):
        """Store the parameters that define this plane wave.

        sizeOfLattice -- lattice edge length L; momenta are multiples of
                         2*pi/L.
        center        -- stored on the instance; not used by this class.
        elec          -- index whose parity flips the sign of the x momentum
                         through (-1)**elec (the y momentum is unaffected).
        direction     -- selects the momentum axis: 0 -> x, 1 -> y
                         (callers in this file only pass 0 or 1).
        mode          -- mode number n; momentum magnitude n*2*pi/L.
        coeff         -- amplitude multiplying the wave.
        """
        self.sizeOfLattice = sizeOfLattice
        self.center = center

        self.elec = elec
        self.direction = direction
        self.mode = mode
        self.coeff = coeff

        # Extra normalisation factor, adjustable later via setNorm().
        self.tonorm = 1.0

    def spinor(self, x, y):
        """Evaluate the wave at lattice site (x, y).

        Returns [a, 0, a, 0] with a = tonorm * coeff * exp(-i * (px*x + py*y)).
        """
        # Momentum components; exactly one of the two axis factors is
        # non-zero when direction is 0 or 1. Computed per call because the
        # attributes are public and may be changed after construction.
        px = ((self.direction + 1) % 2) * (-1) ** self.elec * self.mode * 2 * pi / self.sizeOfLattice
        py = self.direction * self.mode * 2 * pi / self.sizeOfLattice

        # The wave lives in the plane, so k.r is a plain 2-D dot product.
        phase = px * x + py * y
        amplitude = self.tonorm * self.coeff * e ** (-1j * phase)

        # Both spin-up components share one amplitude; spin-down are empty.
        return [amplitude, 0, amplitude, 0]

    def setNorm(self, tonorm):
        """Set the extra normalisation factor applied in spinor()."""
        self.tonorm = tonorm


class Basis:
    """Grid of ``planeWave`` objects with a Gaussian envelope over modes.

    ``self.basis[elec][direction][mode]`` holds one ``planeWave`` for
    elec in {0, 1}, direction in {0, 1} and mode in range(sizeOfBasis),
    weighted by (1/sigma) * exp(-mode**2 / sigma**2).
    """

    def __init__(self, sizeOfLattice, sizeOfBasis, center, sigma):
        self.sizeOfLattice = sizeOfLattice
        self.sizeOfBasis = sizeOfBasis

        self.sigma = sigma
        self.center = center

        self.basis = [
            [
                [
                    planeWave(sizeOfLattice, center, elec, direction, mode,
                              self._envelope(mode, sigma))
                    for mode in range(sizeOfBasis)
                ]
                for direction in range(2)
            ]
            for elec in range(2)
        ]

    @staticmethod
    def _envelope(mode, sigma):
        """Return the Gaussian mode weight (1/sigma) * exp(-mode**2 / sigma**2)."""
        # Convert explicitly: with an int sigma, Python 2 would floor-divide
        # both 1/sigma and mode**2/sigma**2.
        sigma = float(sigma)
        return (1.0 / sigma) * e ** (-mode ** 2 / sigma ** 2)

    def getTotalDensity(self, elec, x, y):
        """Return |psi(x, y)|**2 for the basis functions of ``elec``.

        psi is the coherent (component-wise) sum of the 4-component spinors
        of every direction/mode of ``elec``; the result is
        sum_i conj(psi_i) * psi_i, a non-negative real number.
        """
        psi = np.zeros(4, dtype=complex)
        for direction in range(2):
            for mode in range(self.sizeOfBasis):
                # Add amplitudes element-wise; `list += list` would
                # concatenate the spinors instead of superposing them.
                psi += np.asarray(
                    self.basis[elec][direction][mode].spinor(x, y),
                    dtype=complex,
                )
        # vdot conjugates its first argument: <psi|psi>.
        return np.vdot(psi, psi).real

    def setNorms(self, elec, tonorm):
        """Apply ``tonorm`` to every basis function of ``elec`` via setNorm()."""
        for direction in range(2):
            for mode in range(self.sizeOfBasis):
                self.basis[elec][direction][mode].setNorm(tonorm)


# --- Simulation parameters -------------------------------------------------
sizeOfLattice = 16  # lattice edge length
sizeOfBasis = 8     # number of modes per (elec, direction)

# Kept as module-level names (not used below).
m = 1.0
c = 1.0

# Gaussian envelope centred in the middle of the lattice.
halfLattice = sizeOfLattice / 2.0
center = vector(halfLattice, halfLattice, 0)
sigma = sizeOfLattice / 6.0

basis = Basis(sizeOfLattice, sizeOfBasis, center, sigma)

# Dump the density of elec 0 at every lattice site, row-major in x then y.
for x in range(sizeOfLattice):
    for y in range(sizeOfLattice):
        print(basis.getTotalDensity(0, x, y))
